Triton入门Demo

原版:

srush/Triton-Puzzles: Puzzles for learning Triton

缩减版:

SiriusNEO/Triton-Puzzles-Lite: Puzzles for learning Triton, play it with minimal environment configuration!

Demo

import triton.language as tl

启动Triton函数

Triton 的网格配置由 3 个维度组成:(num_programs_x, num_programs_y, num_programs_z),它们分别控制线程块在 X、Y、Z轴上的分布。具体来说:

  • num_programs_x :
    • 线程块在 X 轴方向的数量。
    • Triton 内核的每个线程块负责计算一个子区域的数据。
  • num_programs_y:
    • 线程块在 Y 轴方向的数量。
    • 如果需要跨多个维度(如 2D 矩阵),可以沿 Y 轴扩展线程块。
  • num_programs_z :
    • 线程块在 Z 轴方向的数量。
    • 通常在 3D 数据(如体积数据或多批次数据)处理中使用。

读取数据

tl.load 是 Triton 的重要函数,用于从 GPU 内存中高效读取数据,同时支持掩码。

tl.load(pointer, mask=None, other=None)

  • pointer:加载数据的内存地址
  • mask:掩码
  • other:mask=False的选项默认值
@triton.jit
def demo1(x_ptr):
    # range [0 1 2 3 4 5 6 7]
    # mask = range < 5 = [1 1 1 1 1 0 0 0]
    range = tl.arange(0, 8)
    x = tl.load(x_ptr + range, range < 5, 0)

def run_demo1():
    demo1[(1, 1, 1)](torch.ones(4, 3))
    print_end_line()

多维

None的意思是该维度增加一个维度为1的值

@triton.jit
def demo2(x_ptr):
    i_range = tl.arange(0, 8)[:, None]
    # i_range, 一列0到7
    j_range = tl.arange(0, 4)[None, :]
    # j_range, 一行0到4
    range = i_range * 4 + j_range
    # range [[ 0  1  2  3]
    #        [ 4  5  6  7]
    #        [ 8  9 10 11]
    #        [12 13 14 15]
    #        [16 17 18 19]
    #        [20 21 22 23]
    #        [24 25 26 27]
    #        [28 29 30 31]]
    x = tl.load(x_ptr + range, (i_range < 4) & (j_range < 3), 0)
    # x [[1. 1. 1. 0.]
    #    [1. 1. 1. 0.]
    #    [1. 1. 1. 0.]
    #    [1. 1. 1. 0.]
    #    [0. 0. 0. 0.]
    #    [0. 0. 0. 0.]
    #    [0. 0. 0. 0.]
     #    [0. 0. 0. 0.]]

pointer是地址的矩阵,掩码是True或False。


写入数据

tl.store(pointer, value, mask=None)

  • pointer:写入数据的目标内存地址
  • value:写入的数据
  • mask=None:掩码
@triton.jit
def demo3(z_ptr):
    range = tl.arange(0, 8)
    z = tl.store(z_ptr + range, 10, range < 5)

并行处理

@triton.jit
def demo4(x_ptr):
    pid = tl.program_id(0)
    range = tl.arange(0, 8) + pid * 8
    x = tl.load(x_ptr + range, range < 20)

def run_demo4():
    x = torch.ones(2, 4, 4)
    demo4[(3, 1, 1)](x)

pid就是grid中的x。

results matching ""

    No results matching ""